[Multidevice] Add TMA bulk copy kernel and P2P transport option #6012

Open
samnordmann wants to merge 1 commit into tma_p2p from tma_integration

Conversation

@samnordmann
Collaborator

  • Add a Hopper TMA (cp.async.bulk) copy kernel (csrc/multidevice/tma_copy.cu) compiled at runtime via NVRTC, and wire it as an alternative P2P data transport alongside the existing copy-engine (cudaMemcpyAsync) path.
  • Add P2pTransport option (NVFUSER_ENABLE=p2p_transport(tma)) that switches sendPost/recvPost in cuda_p2p.cpp between copy-engine (default) and TMA.

@github-actions

Description

  • Add Hopper TMA (Tensor Memory Accelerator) bulk copy kernel compiled at runtime via NVRTC

  • Add P2pTransport option (NVFUSER_ENABLE=p2p_transport(tma)) to switch between copy-engine and TMA

  • Integrate TMA transport into existing sendPost/recvPost functions in cuda_p2p.cpp

  • Simplify and improve test coverage for TMA copy across local, P2P, and multicast scenarios

Changes walkthrough

Relevant files

Enhancement

cuda_p2p.cpp: Add TMA kernel compilation and P2P transport integration
csrc/multidevice/cuda_p2p.cpp (+142/-12)

  • Added TMA copy kernel compilation and launch logic with NVRTC
  • Added getP2pTransport() function to read NVFUSER_ENABLE option
  • Modified recvPost() and sendPost() to conditionally use TMA vs copy-engine
  • Added operator<< for P2pTransport enum
  • Implemented chunked TMA copy handling for large transfers via shared memory

cuda_p2p.h: Add TMA transport declarations and enum
csrc/multidevice/cuda_p2p.h (+14/-2)

  • Added P2pTransport enum with CopyEngine and Tma options
  • Added launchTmaCopy() function declaration
  • Added getP2pTransport() and operator<< declarations

Configuration changes

options.cpp: Register p2p_transport enable option
csrc/options.cpp (+1/-0)

  • Added "p2p_transport" option mapping to EnableOption::P2pTransport

options.h: Add P2pTransport to enable options enum
csrc/options.h (+1/-0)

  • Added P2pTransport to EnableOption enum

Tests

test_multidevice_tma.cpp: Simplify TMA tests using production kernel launcher
tests/cpp/test_multidevice_tma.cpp (+10/-122)

  • Removed NVRTC compilation helpers, now uses production launchTmaCopy()
  • Simplified tests for local device, P2P, and multicast TMA copy scenarios
  • Maintained test coverage for TMA bulk copy functionality

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Thread Safety

    The static module and kernel variables in launchTmaCopy() are not thread-safe. Multiple threads could simultaneously enter the initialization block, potentially causing race conditions during NVRTC compilation and CUDA module loading. Consider adding mutex protection or std::call_once for thread-safe initialization.

    static CUmodule module = nullptr;
    static CUfunction kernel = nullptr;
    
    if (module == nullptr) {
      nvrtcProgram prog;
      NVFUSER_NVRTC_SAFE_CALL(nvrtcCreateProgram(
          &prog,
          nvfuser_resources::tma_copy_cu,
          "tma_copy.cu",
          0,
          nullptr,
          nullptr));
    
      int device = 0;
      NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDevice(&device));
      cudaDeviceProp prop;
      NVFUSER_CUDA_RT_SAFE_CALL(
          cudaGetDeviceProperties(&prop, device));
    
      NVF_CHECK(
          prop.major >= 9,
          "TMA transport requires Compute Capability >= 9.0 (Hopper+). "
          "Current device ",
          device,
          " is Compute Capability ",
          prop.major,
          ".",
          prop.minor);
    
      std::string arch_arg = "--gpu-architecture=compute_" +
          std::to_string(prop.major) + std::to_string(prop.minor);
      std::vector<const char*> opts = {
          arch_arg.c_str(), "--std=c++17"};
    
      nvrtcResult res =
          nvrtcCompileProgram(prog, (int)opts.size(), opts.data());
      if (res != NVRTC_SUCCESS) {
        size_t logSize;
        NVFUSER_NVRTC_SAFE_CALL(
            nvrtcGetProgramLogSize(prog, &logSize));
        std::vector<char> log(logSize);
        NVFUSER_NVRTC_SAFE_CALL(
            nvrtcGetProgramLog(prog, log.data()));
        NVF_ERROR(
            false, "TMA kernel compilation failed:\n", log.data());
      }
    
      size_t ptxSize;
      NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTXSize(prog, &ptxSize));
      std::vector<char> ptx(ptxSize);
      NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTX(prog, ptx.data()));
      NVFUSER_NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog));
    
      NVFUSER_CUDA_SAFE_CALL(
          cuModuleLoadData(&module, ptx.data()));
      NVFUSER_CUDA_SAFE_CALL(
          cuModuleGetFunction(&kernel, module, "tma_copy_1d"));
    }
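One possible shape for the thread-safe lazy initialization suggested above is sketched here. It uses a function-local static initializer rather than std::call_once; C++11 guarantees that such an initializer runs exactly once even under concurrent calls, so the effect is the same. CompiledKernel, compileOnce, and getTmaKernel are hypothetical stand-ins for the CUmodule/CUfunction pair and the NVRTC steps, not names from the PR.

```cpp
// Stand-in for the (CUmodule, CUfunction) pair cached by launchTmaCopy().
struct CompiledKernel {
  int module = 0;
  int function = 0;
};

// Counts how many times the expensive step ran (for demonstration only).
inline int& compileCount() {
  static int count = 0;
  return count;
}

// Placeholder for NVRTC compilation followed by cuModuleLoadData and
// cuModuleGetFunction.
inline CompiledKernel compileOnce() {
  ++compileCount();
  return CompiledKernel{1, 2};
}

// Thread-safe lazy initialization: the initializer of a function-local
// static runs exactly once, and concurrent callers wait until it finishes.
// If the initializer throws, the variable stays uninitialized and the next
// call retries.
inline const CompiledKernel& getTmaKernel() {
  static const CompiledKernel kernel = compileOnce();
  return kernel;
}
```

Compared with the `if (module == nullptr)` check, this removes the window in which two threads can both observe a null module and compile twice. Applying the same pattern to launchAlltoallvKernel and launchMulticastKernel would be a separate change.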
    Resource Cleanup

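    If NVRTC compilation fails (lines 398-407), NVF_ERROR throws before nvrtcDestroyProgram is reached, so the nvrtcProgram is leaked on that path; the same applies if nvrtcGetPTXSize or nvrtcGetPTX fail. By the time cuModuleLoadData and cuModuleGetFunction run, the program has already been destroyed, so a failure there does not leak it. Consider releasing the program on every exit path, for example with a scope guard, to prevent the leak.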

    nvrtcResult res =
        nvrtcCompileProgram(prog, (int)opts.size(), opts.data());
    if (res != NVRTC_SUCCESS) {
      size_t logSize;
      NVFUSER_NVRTC_SAFE_CALL(
          nvrtcGetProgramLogSize(prog, &logSize));
      std::vector<char> log(logSize);
      NVFUSER_NVRTC_SAFE_CALL(
          nvrtcGetProgramLog(prog, log.data()));
      NVF_ERROR(
          false, "TMA kernel compilation failed:\n", log.data());
    }
    
    size_t ptxSize;
    NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTXSize(prog, &ptxSize));
    std::vector<char> ptx(ptxSize);
    NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTX(prog, ptx.data()));
    NVFUSER_NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog));
    
    NVFUSER_CUDA_SAFE_CALL(
        cuModuleLoadData(&module, ptx.data()));
    NVFUSER_CUDA_SAFE_CALL(
        cuModuleGetFunction(&kernel, module, "tma_copy_1d"));
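A scope guard is one way to make sure a handle is released on every exit path, including when an error macro throws. The sketch below is standalone and uses hypothetical stand-ins (FakeProgram, createProgram, destroyProgram) in place of the NVRTC calls; ScopeExit is an illustrative helper, not something the PR or nvFuser is known to provide.

```cpp
#include <stdexcept>
#include <utility>

// Minimal scope guard: runs the stored callable when it goes out of scope,
// including during stack unwinding after an exception.
template <typename F>
class ScopeExit {
 public:
  explicit ScopeExit(F fn) : fn_(std::move(fn)) {}
  ~ScopeExit() { fn_(); }
  ScopeExit(const ScopeExit&) = delete;
  ScopeExit& operator=(const ScopeExit&) = delete;

 private:
  F fn_;
};

// Hypothetical stand-ins for nvrtcCreateProgram / nvrtcDestroyProgram.
struct FakeProgram {
  bool alive = false;
};
inline void createProgram(FakeProgram& p) { p.alive = true; }
inline void destroyProgram(FakeProgram& p) { p.alive = false; }

// Returns true if the program handle was released even though the
// simulated compilation failed by throwing.
inline bool releasedOnFailure() {
  FakeProgram prog;
  try {
    createProgram(prog);
    auto cleanup = [&prog] { destroyProgram(prog); };
    ScopeExit<decltype(cleanup)> guard(cleanup);
    throw std::runtime_error("TMA kernel compilation failed");
  } catch (const std::runtime_error&) {
    // The guard's destructor already ran while the stack was unwinding.
  }
  return !prog.alive;
}
```

With a guard like this created right after the program handle, the explicit destroy call at the end of the success path becomes unnecessary, because the guard covers both outcomes.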
    Performance Validation

    The PR lacks performance comparison data between TMA and copy engine transports. Consider adding benchmarking results or performance metrics to validate that TMA provides expected benefits over the default copy engine, especially for different transfer sizes and patterns.

          if (getP2pTransport() == P2pTransport::Tma) {
            launchTmaCopy(
                ipc_handles.local().ptr(),
                ipc_handles.peer().ptr(),
                count,
                stream);
          } else {
            NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync(
                ipc_handles.local().ptr(),
                ipc_handles.peer().ptr(),
                count,
                cudaMemcpyDeviceToDevice,
                stream));
          }
          // Signals completion
          WriteValue32ToLocalAndPeer(stream, ipc_handles, IpcSemaphore::kIdle);
          break;
        }
        case P2pProtocol::Put: {
          WriteValue32ToLocalAndPeer(
              stream, ipc_handles, IpcSemaphore::kInProgress);
          break;
        }
        default:
          NVF_ERROR("Invalid P2P protocol: ", protocol);
      }
    }
    
    void recvWait(const P2pIpcHandle& ipc_handles, CUstream stream) {
      P2pProtocol protocol = getP2pProtocol();
      switch (protocol) {
        case P2pProtocol::Put:
          NVFUSER_CUDA_SAFE_CALL(cuStreamWaitValue32(
              stream,
              reinterpret_cast<CUdeviceptr>(ipc_handles.local().semaphore()),
              (cuuint32_t)(IpcSemaphore::kIdle),
              CU_STREAM_WAIT_VALUE_EQ));
          break;
        case P2pProtocol::Get:
          break;
        default:
          NVF_ERROR("Invalid P2P protocol: ", protocol);
      }
    }
    
    void sendPost(const P2pIpcHandle& ipc_handles, int64_t count, CUstream stream) {
      P2pProtocol protocol = getP2pProtocol();
      switch (protocol) {
        case P2pProtocol::Get:
          // signal to self and peer that transfer is in progress
          WriteValue32ToLocalAndPeer(
              stream, ipc_handles, IpcSemaphore::kInProgress);
          break;
        case P2pProtocol::Put: {
          // wait for receiver to be ready
          NVFUSER_CUDA_SAFE_CALL(cuStreamWaitValue32(
              stream,
              reinterpret_cast<CUdeviceptr>(ipc_handles.local().semaphore()),
              (cuuint32_t)(IpcSemaphore::kInProgress),
              CU_STREAM_WAIT_VALUE_EQ));
          // Put the data to the receiver
          if (getP2pTransport() == P2pTransport::Tma) {
            launchTmaCopy(
                ipc_handles.peer().ptr(),
                ipc_handles.local().ptr(),
                count,
                stream);
          } else {
            NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync(
                ipc_handles.peer().ptr(),
                ipc_handles.local().ptr(),
                count,
                cudaMemcpyDeviceToDevice,
                stream));
          }

    @greptile-apps
    Contributor

    greptile-apps bot commented Feb 25, 2026

    Greptile Summary

    Adds Hopper TMA (cp.async.bulk) as an alternative P2P transport alongside the existing copy-engine path, controlled by NVFUSER_ENABLE=p2p_transport(tma).

    Key changes:

    • Implements launchTmaCopy() that compiles tma_copy.cu at runtime via NVRTC and handles arbitrary sizes via chunking
    • Adds P2pTransport enum (CopyEngine, Tma) and integrates transport selection into sendPost/recvPost
    • Enforces Compute Capability >= 9.0 check and 16-byte alignment requirement
    • Simplifies test suite by removing duplicate test-only TMA implementation
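The chunking and alignment requirements listed above can be sketched as follows. The 48 KiB chunk size is taken from the "Size > 48KB" node in the review flowchart and the 16-byte figure from the summary; the exact constants, and the names Chunk, isTmaAligned, and splitIntoChunks, are assumptions for illustration and may differ from what launchTmaCopy() does.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Assumed values: 48 KiB per launch and 16-byte alignment.
constexpr std::size_t kMaxChunkBytes = 48 * 1024;
constexpr std::size_t kAlignmentBytes = 16;

struct Chunk {
  std::size_t offset;  // byte offset from the start of the buffers
  std::size_t size;    // bytes copied by this launch
};

// Returns true when both the pointer and the byte count are 16-byte aligned.
inline bool isTmaAligned(const void* ptr, std::size_t count) {
  return reinterpret_cast<std::uintptr_t>(ptr) % kAlignmentBytes == 0 &&
      count % kAlignmentBytes == 0;
}

// Splits `count` bytes into consecutive chunks of at most kMaxChunkBytes.
// Because the chunk size and `count` are both multiples of 16, every chunk
// offset and size stays 16-byte aligned.
inline std::vector<Chunk> splitIntoChunks(std::size_t count) {
  if (count % kAlignmentBytes != 0) {
    throw std::invalid_argument("count must be a multiple of 16 bytes");
  }
  std::vector<Chunk> chunks;
  for (std::size_t offset = 0; offset < count; offset += kMaxChunkBytes) {
    chunks.push_back({offset, std::min(kMaxChunkBytes, count - offset)});
  }
  return chunks;
}
```

Under these assumptions a 100 KiB transfer becomes three launches of 48 KiB, 48 KiB, and 4 KiB, each starting at an aligned offset.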

    Issues found:

    • Thread-safety issue in static initialization of TMA kernel module (also affects existing kernels)

    Confidence Score: 3/5

    • Safe for merging with awareness of pre-existing thread-safety pattern
    • The thread-safety issue in static initialization is a real concern, but it follows the same pattern as existing kernels in the file (launchAlltoallvKernel, launchMulticastKernel). If those haven't caused issues in practice, this likely won't either. The implementation is well-structured with proper error handling and alignment checks.
    • Pay attention to csrc/multidevice/cuda_p2p.cpp due to the static initialization pattern

    Important Files Changed

    Filename Overview
    csrc/multidevice/cuda_p2p.cpp Adds TMA copy kernel compilation and transport switching logic; potential thread-safety issue in static initialization
    csrc/multidevice/cuda_p2p.h Adds P2pTransport enum and launchTmaCopy declaration; clean interface additions
    csrc/options.cpp Adds p2p_transport option to enable map; straightforward configuration change
    csrc/options.h Adds P2pTransport enum value with documentation; clean enum addition
    tests/cpp/test_multidevice_tma.cpp Simplifies tests to use production launchTmaCopy; removes duplicate test-only implementation

    Flowchart

    %%{init: {'theme': 'neutral'}}%%
    flowchart TD
        A[P2P Send/Recv Operation] --> B{getP2pTransport}
        B -->|CopyEngine default| C[cudaMemcpyAsync]
        B -->|Tma NVFUSER_ENABLE=p2p_transport tma| D[launchTmaCopy]
        
        D --> E{Module initialized?}
        E -->|No| F[Compile tma_copy.cu via NVRTC]
        F --> G[Check SM >= 9.0]
        G --> H[Cache CUmodule & CUfunction]
        E -->|Yes| I[Use cached kernel]
        H --> I
        
        I --> J{Size > 48KB chunk?}
        J -->|Yes| K[Split into chunks]
        J -->|No| L[Single launch]
        K --> M[Launch multiple TMA kernels]
        L --> N[Launch TMA kernel]
        M --> O[GMEM -> SMEM -> GMEM]
        N --> O
        
        C --> P[Direct device-to-device copy]
        O --> Q[Complete]
        P --> Q
    
    Last reviewed commit: 13f72a5

    Contributor

    @greptile-apps greptile-apps bot left a comment

    5 files reviewed, 1 comment

    Comment on lines +362 to +363
    static CUmodule module = nullptr;
    static CUfunction kernel = nullptr;
    Contributor

    Static initialization lacks thread-safety protection. Multiple threads calling launchTmaCopy concurrently could race on the module == nullptr check (line 365), causing duplicate compilations or accessing partially-initialized state.

    Other kernels in this file (launchAlltoallvKernel, launchMulticastKernel) have the same pattern. Consider adding mutex protection or using std::call_once for thread-safe lazy initialization if concurrent calls are possible.
